Compare DMS to natural sequence evolution¶
In [1]:
# this cell is tagged parameters for papermill parameterization
# (these are placeholder defaults; papermill injects real values in the next cell)
dms_summary_csv = None  # input CSV summarizing DMS phenotype measurements
growth_rates_csv = None  # input CSV of clade growth-rate estimates
pango_consensus_seqs_json = None  # input JSON of Pango consensus-sequence summaries
starting_clades = None  # analyze clades descended from these Pango clades
dms_clade = None  # clade used as the parental strain in the DMS experiments
n_random = None  # number of DMS-data randomizations for the null distribution
exclude_clades = None  # clades to drop from the analysis
pango_dms_phenotypes_csv = None  # output CSV of per-clade DMS phenotypes
pango_by_date_html = None  # output chart: phenotypes vs designation date
pango_affinity_vs_escape_html = None  # output chart: ACE2 affinity vs sera escape
pango_dms_vs_growth_regression_html = None  # output chart: regression, full spike
pango_dms_vs_growth_regression_by_domain_html = None  # output chart: regression by domain
pango_dms_vs_growth_corr_html = None  # output chart: simple correlations, full spike
pango_dms_vs_growth_corr_by_domain_html = None  # output chart: correlations by domain
exclude_clades_with_muts = None  # drop clades carrying any of these mutations
In [2]:
# Parameters
# (this cell was injected by papermill and overrides the placeholder defaults above)
starting_clades = ["BA.2", "BA.5", "XBB"]
dms_clade = "XBB.1.5"
dms_summary_csv = "results/summaries/summary.csv"
growth_rates_csv = "MultinomialLogisticGrowth/model_fits/rates.csv"
pango_consensus_seqs_json = (
    "results/compare_natural/pango-consensus-sequences_summary.json"
)
pango_dms_phenotypes_csv = "results/compare_natural/pango_dms_phenotypes.csv"
pango_by_date_html = "results/compare_natural/pango_dms_phenotypes_by_date.html"
pango_affinity_vs_escape_html = "results/compare_natural/pango_affinity_vs_escape.html"
pango_dms_vs_growth_regression_html = (
    "results/compare_natural/pango_dms_vs_growth_regression.html"
)
pango_dms_vs_growth_regression_by_domain_html = (
    "results/compare_natural/pango_dms_vs_growth_regression_by_domain.html"
)
pango_dms_vs_growth_corr_html = "results/compare_natural/pango_dms_vs_growth_corr.html"
pango_dms_vs_growth_corr_by_domain_html = (
    "results/compare_natural/pango_dms_vs_growth_corr_by_domain.html"
)
n_random = 200
exclude_clades = []
exclude_clades_with_muts = []
In [3]:
import collections
import itertools
import json
import math
import re
import altair as alt
import numpy
import pandas as pd
import polyclonal.plot
import scipy.stats
import statsmodels.api
_ = alt.data_transformers.disable_max_rows()
Read Pango clades and mutations¶
In [4]:
# load the pre-computed Pango consensus-sequence summary (dict keyed by clade)
with open(pango_consensus_seqs_json) as f:
    pango_clades = json.load(f)
def n_child_clades(c):
    """Return the total number of clades descended from Pango clade `c`."""
    # count each direct child plus, recursively, all of its descendants
    total = 0
    for child in pango_clades[c]["children"]:
        total += 1 + n_child_clades(child)
    return total
def build_records(c, recs):
    """Recursively add clade `c` and all its descendants to `recs`.

    `recs` is a dict of parallel lists keyed by "clade", "n_child_clades",
    "date", and "muts_from_ref" (spike mutations relative to the reference);
    a clade already recorded is skipped so shared descendants are not duplicated.
    """
    if c in recs["clade"]:
        return
    recs["clade"].append(c)
    recs["n_child_clades"].append(n_child_clades(c))
    recs["date"].append(pango_clades[c]["designationDate"])
    recs["muts_from_ref"].append(
        [
            mut.split(":")[1]  # strip the "S:" gene prefix
            for field in ["aaSubstitutions", "aaDeletions"]
            for mut in pango_clades[c][field]
            if mut.startswith("S:")  # keep only spike mutations
        ]
    )
    for c_child in pango_clades[c]["children"]:
        build_records(c_child, recs)
# build data frame of all clades descended from the starting clades
records = collections.defaultdict(list)
for starting_clade in starting_clades:
    build_records(starting_clade, records)
pango_df = pd.DataFrame(records).query("clade not in @exclude_clades")

# spike mutations (relative to the reference) of the DMS parental clade
dms_clade_mutations_from_ref = pango_df.set_index("clade").at[
    dms_clade, "muts_from_ref"
]
def mutations_from(muts, from_muts):
    """Get mutations of `muts` relative to `from_muts`.

    Both arguments list mutation strings relative to the same reference
    (e.g. "A83V", with "-" denoting a deletion). Returns a list of
    mutation strings describing `muts` relative to `from_muts`, sorted
    by site.
    """
    # mutations present in exactly one of the two sequences
    new_muts = set(muts).symmetric_difference(from_muts)
    # raw string: "\-" / "\d" are invalid escapes in a plain string literal
    assert all(re.fullmatch(r"[A-Z\-]\d+[A-Z\-]", m) for m in new_muts)
    # group the differing mutations by site
    new_muts_d = collections.defaultdict(list)
    for m in new_muts:
        new_muts_d[int(m[1: -1])].append(m)
    new_muts_list = []
    for _, ms in sorted(new_muts_d.items()):
        if len(ms) == 1:
            # only one sequence differs from the reference at this site
            m = ms[0]
            if m in muts:
                new_muts_list.append(m)
            else:
                # mutation is in `from_muts`, so invert it (e.g. A83V -> V83A)
                assert m in from_muts
                new_muts_list.append(m[-1] + m[1: -1] + m[0])
        else:
            # both sequences mutated this site to different identities
            m, from_m = ms
            if m not in muts:
                from_m, m = m, from_m
            assert m in muts and from_m in from_muts
            new_muts_list.append(from_m[-1] + m[1: ])
    return new_muts_list
# re-express each clade's mutations relative to the DMS clade (rather than
# the reference) and parse designation dates
pango_df = (
    pango_df
    .assign(
        muts_from_dms_clade=lambda x: x["muts_from_ref"].apply(
            mutations_from, args=(dms_clade_mutations_from_ref,),
        ),
        date=lambda x: pd.to_datetime(x["date"]),
    )
    .drop(columns="muts_from_ref")
    .sort_values("date")
    .reset_index(drop=True)
)

# drop clades carrying any explicitly excluded mutations
for mut in exclude_clades_with_muts:
    pango_df = pango_df[pango_df["muts_from_dms_clade"].map(lambda ms: mut not in ms)]

pango_df
Out[4]:
| clade | n_child_clades | date | muts_from_dms_clade | |
|---|---|---|---|---|
| 0 | BA.2 | 384 | 2021-12-07 | [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339... |
| 1 | BA.2.1 | 0 | 2022-02-25 | [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339... |
| 2 | BA.2.2 | 1 | 2022-02-25 | [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339... |
| 3 | BA.2.3 | 52 | 2022-02-25 | [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339... |
| 4 | BA.2.7 | 0 | 2022-03-25 | [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339... |
| ... | ... | ... | ... | ... |
| 1575 | GJ.1.2.7 | 0 | 2023-10-04 | [K182N, V252G, D253G, K478R, P521S] |
| 1576 | GJ.1.2.8 | 0 | 2023-10-04 | [K182N, V252G, D253G, P521S, T747I] |
| 1577 | GJ.5 | 1 | 2023-10-04 | [K182N, V252G, D253G, K478R, P521S] |
| 1578 | GJ.5.1 | 0 | 2023-10-04 | [K182N, V252G, D253G, S255F, K478R, P521S] |
| 1579 | JD.2 | 0 | 2023-10-04 | [] |
1580 rows × 4 columns
Assign DMS phenotypes to Pango clades¶
First define function that assigns DMS phenotypes to mutations:
In [5]:
# read the DMS data
dms_summary = pd.read_csv(dms_summary_csv).rename(
    columns={
        "spike mediated entry": "cell entry",
        "human sera escape": "sera escape",
    }
)

# specify DMS phenotypes of interest
phenotypes = [
    "sera escape",
    "ACE2 affinity",
    "cell entry",
]
assert set(phenotypes).issubset(dms_summary.columns)

# color used for each phenotype in the charts below
phenotype_colors = {
    "sera escape": "red",
    "ACE2 affinity": "blue",
    "cell entry": "purple",
}
assert set(phenotypes) == set(phenotype_colors)

# dict that maps site to wildtype in DMS
dms_wt = dms_summary.set_index("site")["wildtype"].to_dict()

# dict that maps site to region in DMS
site_to_region = dms_summary.set_index("site")["region"].to_dict()
def mut_dms(m, dms_data):
    """Get DMS phenotypes for mutation `m`.

    `dms_data` maps (site, wildtype, mutant) tuples to dicts keyed by the
    DMS phenotypes. Returns a dict keyed by `phenotypes` plus "is_RBD".
    Effects are looked up relative to the DMS wildtype: negated when the
    mutation reverts to wildtype, and differenced when neither the parent
    nor the mutant identity is the DMS wildtype. Missing data yields
    `pd.NA` values.
    """
    null_d = {k: pd.NA for k in phenotypes}
    if pd.isnull(m) or int(m[1: -1]) not in dms_wt:
        # no mutation at all (clade matches DMS clade) or site not in DMS data
        d = null_d
        d["is_RBD"] = pd.NA
    else:
        parent = m[0]
        site = int(m[1: -1])
        mut = m[-1]
        wt = dms_wt[site]
        if parent == wt:
            try:
                # copy so that adding "is_RBD" below does not mutate the
                # shared dict stored in the caller's `dms_data`
                d = dict(dms_data[(site, parent, mut)])
            except KeyError:
                d = null_d
        elif mut == wt:
            # mutation reverts to the DMS wildtype, so negate the effect
            try:
                d = {k: -v for (k, v) in dms_data[(site, mut, parent)].items()}
            except KeyError:
                d = null_d
        else:
            # neither identity is the DMS wildtype: difference of the two effects
            try:
                parent_d = dms_data[(site, wt, parent)]
                mut_d = dms_data[(site, wt, mut)]
                d = {p: mut_d[p] - parent_d[p] for p in phenotypes}
            except KeyError:
                d = null_d
        d["is_RBD"] = (site_to_region[site] == "RBD")
    assert list(d) == phenotypes + ["is_RBD"]
    return d
Now assign phenotypes to pango clades. We do this both using the actual DMS data and randomizing the DMS data among measured mutations:
In [6]:
def get_pango_dms_df(dms_data_dict):
    """Given dict mapping mutations to DMS data, get data frame of values for Pango clades.

    `dms_data_dict` maps (site, wildtype, mutant) to dicts of phenotype
    values (as consumed by `mut_dms`). Returns one row per clade per DMS
    phenotype, with the clade's phenotype as the sum of its mutations'
    effects (overall, RBD-only, and non-RBD-only); mutations lacking DMS
    data count as zero effect and are listed in a "missing data" column.
    """
    pango_dms_df = (
        pango_df
        # put one mutation in each column
        .explode("muts_from_dms_clade")
        .rename(columns={"muts_from_dms_clade": "mutation"})
        # to add multiple columns: https://stackoverflow.com/a/46814360
        .apply(
            lambda cols: pd.concat([cols, pd.Series(mut_dms(cols["mutation"], dms_data_dict))]),
            axis=1,
        )
        # long form: one row per mutation per phenotype
        .melt(
            id_vars=["clade", "date", "n_child_clades", "mutation", "is_RBD"],
            value_vars=phenotypes,
            var_name="DMS_phenotype",
            value_name="mutation_effect",
        )
        .assign(
            # semicolon-delimited list of all the clade's mutations
            muts_from_dms_clade=lambda x: x.groupby(["clade", "DMS_phenotype"])["mutation"].transform(
                lambda ms: "; ".join([m for m in ms if not pd.isnull(m)])
            ),
            # mutations that exist but lack a DMS measurement
            mutation_missing=lambda x: x["mutation"].where(
                x["mutation_effect"].isnull() & x["mutation"].notnull(),
                pd.NA,
            ),
            muts_from_dms_clade_missing_data=lambda x: (
                x.groupby(["clade", "DMS_phenotype"])["mutation_missing"]
                .transform(lambda ms: "; ".join([m for m in ms if not pd.isnull(m)]))
            ),
            # missing effects count as zero when summing over mutations
            mutation_effect=lambda x: x["mutation_effect"].fillna(0),
            is_RBD=lambda x: x["is_RBD"].fillna(False),
            mutation_effect_RBD=lambda x: x["mutation_effect"] * x["is_RBD"].astype(int),
            mutation_effect_nonRBD=lambda x: x["mutation_effect"] * (~x["is_RBD"]).astype(int),
        )
        # sum mutation effects to clade-level phenotypes
        .groupby(
            [
                "clade",
                "date",
                "n_child_clades",
                "muts_from_dms_clade",
                "muts_from_dms_clade_missing_data",
                "DMS_phenotype",
            ],
            as_index=False,
        )
        .aggregate(
            phenotype=pd.NamedAgg("mutation_effect", "sum"),
            phenotype_RBD_only=pd.NamedAgg("mutation_effect_RBD", "sum"),
            phenotype_nonRBD_only=pd.NamedAgg("mutation_effect_nonRBD", "sum"),
        )
        .rename(
            columns={
                "muts_from_dms_clade": f"muts_from_{dms_clade}",
                "muts_from_dms_clade_missing_data": f"muts_from_{dms_clade}_missing_data",
            },
        )
        .sort_values(["date", "DMS_phenotype"])
        .reset_index(drop=True)
    )
    # sanity checks: no clades lost, and domain contributions sum to the total
    assert set(pango_df["clade"]) == set(pango_dms_df["clade"])
    assert numpy.allclose(
        pango_dms_df["phenotype"],
        pango_dms_df["phenotype_RBD_only"] + pango_dms_df["phenotype_nonRBD_only"]
    )
    return pango_dms_df
# First, get the actual DMS data mapped to phenotype
dms_data_dict_actual = (
    dms_summary
    .set_index(["site", "wildtype", "mutant"])
    [phenotypes]
    .to_dict(orient="index")
)
pango_dms_df = get_pango_dms_df(dms_data_dict_actual)

print(f"Saving Pango DMS phenotypes to {pango_dms_phenotypes_csv}")
pango_dms_df.to_csv(pango_dms_phenotypes_csv, float_format="%.4f", index=False)

# Now get the randomized DMS data mapped to phenotype
pango_dms_dfs_rand = []
numpy.random.seed(0)  # seed so the randomizations are reproducible
for irandom in range(1, n_random + 1):
    # randomize the non-null DMS data for each phenotype
    dms_summary_rand = dms_summary.copy()
    for phenotype in phenotypes:
        dms_summary_rand = dms_summary_rand.assign(
            **{phenotype: lambda x: numpy.random.permutation(x[phenotype].values)}
        )
    dms_data_dict_rand = (
        dms_summary_rand
        .set_index(["site", "wildtype", "mutant"])
        [phenotypes]
        .to_dict(orient="index")
    )
    pango_dms_dfs_rand.append(get_pango_dms_df(dms_data_dict_rand).assign(randomize=irandom))

# all randomizations concatenated
pango_dms_df_rand = pd.concat(pango_dms_dfs_rand)
Saving Pango DMS phenotypes to results/compare_natural/pango_dms_phenotypes.csv
Plot phenotypes of Pango clades¶
Plot phenotypes of Pango clades versus their designation dates:
In [7]:
# spike-region phenotype columns, melted to long form below for plotting
region_cols = {
    "phenotype": "full spike",
    "phenotype_RBD_only": "RBD only",
    "phenotype_nonRBD_only": "non-RBD only",
}
pango_chart_df = (
    pango_dms_df
    .melt(
        id_vars=[c for c in pango_dms_df if c not in region_cols],
        value_vars=region_cols,
        var_name="spike_region",
        value_name="phenotype value",
    )
    .assign(
        spike_region=lambda x: x["spike_region"].map(region_cols),
    )
    .rename(columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"})
)

# columns cannot have "." in them for Altair
col_renames = {c: c.replace(".", "_") for c in pango_chart_df.columns if "." in c}
col_renames_rev = {v: k for (k, v) in col_renames.items()}
pango_chart_df = pango_chart_df.rename(columns=col_renames)

# highlights a clade across all panels on mouseover (also reused in later cells)
clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

base_pango_chart = (
    alt.Chart(pango_chart_df)
    .encode(
        tooltip=[
            alt.Tooltip(c, title=col_renames_rev[c] if c in col_renames_rev else c)
            for c in pango_chart_df.columns
        ],
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(60), alt.value(40)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
        color=alt.Color(
            "DMS_phenotype",
            legend=None,
            scale=alt.Scale(
                range=list(phenotype_colors.values()),
                domain=list(phenotype_colors.keys()),
            ),
        ),
    )
    .mark_circle(stroke="black")
    .properties(width=300, height=125)
)

# one row of panels per phenotype; only edge rows get x-axis labels / headers
phenotype_pango_charts = []
for phenotype in phenotypes:
    first_row = (phenotype == phenotypes[0])
    last_row = (phenotype == phenotypes[-1])
    phenotype_pango_charts.append(
        base_pango_chart
        .transform_filter(alt.datum["DMS_phenotype"] == phenotype)
        .encode(
            x=alt.X(
                "date",
                title="designation date of clade" if last_row else None,
                axis=(
                    alt.Axis(titleFontSize=12, labelOverlap=True, format="%b-%Y", labelAngle=-90)
                    if last_row
                    else None
                ),
                scale=alt.Scale(nice=False, padding=3),
            ),
            y=alt.Y(
                "phenotype value",
                title=phenotype,
                axis=alt.Axis(titleFontSize=12),
                scale=alt.Scale(nice=False, padding=3),
            ),
            column=alt.Column(
                "spike_region",
                # NOTE(review): this sorts by the keys of `region_cols`, but the
                # column holds the mapped values ("full spike", ...) — verify
                # the intended panel order (list(region_cols.values())?)
                sort=list(region_cols),
                title=None,
                header=(
                    alt.Header(labelFontSize=12, labelFontStyle="bold", labelPadding=4)
                    if first_row
                    else None
                ),
                spacing=4,
            ),
        )
    )

pango_chart = (
    alt.vconcat(*phenotype_pango_charts, spacing=4)
    .configure_axis(grid=False)
    .add_params(clade_selection)
    .properties(
        title=alt.TitleParams(
            f"DMS predicted phenotypes of Pango clades descended from {', '.join(starting_clades)}",
            anchor="middle",
            fontSize=16,
            dy=-5,
        ),
    )
)

print(f"Saving chart to {pango_by_date_html}")
pango_chart.save(pango_by_date_html)

pango_chart
Saving chart to results/compare_natural/pango_dms_phenotypes_by_date.html
Out[7]:
Pango clade affinity versus escape scatter plot¶
In [8]:
# pivot so each clade is one row with a column for each DMS phenotype
pango_scatter_df = (
    pango_dms_df
    .pivot_table(
        index=[
            c
            for c in pango_dms_df
            if c not in {"DMS_phenotype", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"}
        ],
        values="phenotype",
        columns="DMS_phenotype",
    )
    .reset_index()
    .rename(columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"})
    .rename(columns=col_renames)
)

pango_scatter_chart = (
    alt.Chart(pango_scatter_df)
    .encode(
        x=alt.X(
            "ACE2 affinity",
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=5),
        ),
        y=alt.Y(
            "sera escape",
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=5),
        ),
        tooltip=[
            alt.Tooltip(c, title=col_renames_rev[c] if c in col_renames_rev else c)
            for c in pango_scatter_df.columns
        ],
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(100), alt.value(55)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
    )
    .mark_circle(stroke="red", color="black")
    .add_params(clade_selection)
    .configure_axis(grid=False)
    .properties(
        title=alt.TitleParams(
            [
                "DMS predicted ACE2 affinity vs serum escape",
                # list all starting clades; the prior `{starting_clade}` was a
                # stale loop variable naming only the last starting clade
                f"for Pango clades descended from {', '.join(starting_clades)}"
            ],
            anchor="middle",
            fontSize=14,
            dy=-5,
        ),
    )
    .properties(width=300, height=300)
)

print(f"Saving chart to {pango_affinity_vs_escape_html}")
pango_scatter_chart.save(pango_affinity_vs_escape_html)

pango_scatter_chart
Saving chart to results/compare_natural/pango_affinity_vs_escape.html
Out[8]:
Correlate with clade growth¶
In [9]:
# read clade growth-rate estimates from the multinomial logistic fits
growth_rates = pd.read_csv(growth_rates_csv).rename(
    columns={"pango": "clade", "seq_volume": "number sequences"}
)
# every clade with a growth rate must be a known Pango clade
if (invalid_clades := set(growth_rates["clade"]) - set(pango_clades)):
    raise ValueError(f"Growth rates specified for {invalid_clades}")

# merge growth rates with the actual and randomized DMS phenotypes
pango_dms_growth_df = pango_dms_df.merge(growth_rates, on="clade", validate="many_to_one")
pango_dms_growth_df_rand = pango_dms_df_rand.merge(growth_rates, on="clade", validate="many_to_one")

print(
    f"{growth_rates['clade'].nunique()} clades have growth rates estimates.\n"
    f"{pango_dms_df['clade'].nunique()} clades have DMS estimates.\n"
    f"{pango_dms_growth_df['clade'].nunique()} clades have growth and DMS estimates"
)

print("Simple correlations:")
display(
    pango_dms_growth_df
    .groupby("DMS_phenotype")
    [["R", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"]]
    .corr()
    [["R"]]
)
990 clades have growth rates estimates. 1580 clades have DMS estimates. 923 clades have growth and DMS estimates Simple correlations:
| R | ||
|---|---|---|
| DMS_phenotype | ||
| ACE2 affinity | R | 1.000000 |
| phenotype | -0.488450 | |
| phenotype_RBD_only | -0.309932 | |
| phenotype_nonRBD_only | -0.331831 | |
| cell entry | R | 1.000000 |
| phenotype | 0.788332 | |
| phenotype_RBD_only | 0.812665 | |
| phenotype_nonRBD_only | 0.454013 | |
| sera escape | R | 1.000000 |
| phenotype | 0.930667 | |
| phenotype_RBD_only | 0.927965 | |
| phenotype_nonRBD_only | 0.359769 |
Plot clade growth rate (R) versus designation date, with point sizes proportional to the log of the number of sequences in each clade:
In [10]:
# quick diagnostic chart: clade growth rate (R) over designation date,
# point size on a log scale by number of sequences in the clade
(
    alt.Chart(pango_dms_growth_df)
    .mark_circle(opacity=0.25, color="black")
    .encode(
        x=alt.X("date"),
        y=alt.Y("R"),
        size=alt.Size("number sequences", scale=alt.Scale(type="log")),
        tooltip=list(pango_dms_growth_df.columns),
    )
)
Out[10]:
Now perform OLS, weighting clades by log number of sequences:
In [11]:
# pivot DMS data to get phenotypes
def pivot_for_ols_vars(df):
    """Pivot long-form DMS/growth data to one row per clade for the OLS fits.

    Output columns are named "<phenotype> (<domain>)" for every DMS
    phenotype crossed with the spike domains "full spike", "RBD", and
    "non RBD"; clade metadata and growth rate "R" are kept as columns.
    """
    ols_vars = (
        df
        .rename(
            columns={
                "phenotype": "full spike",
                "phenotype_RBD_only": "RBD",
                "phenotype_nonRBD_only": "non RBD",
            }
        )
        .assign(
            # group muts missing data from all phenotypes
            muts_from_DMS_clade_missing_data=lambda x: (
                x.groupby("clade")
                [f"muts_from_{dms_clade}_missing_data"]
                .transform(
                    lambda s: "; ".join(dict.fromkeys([m for ms in s.str.split("; ") for m in ms if m]))
                )
            ),
        )
        .rename(columns={f"muts_from_{dms_clade}": "muts_from_DMS_clade"})
        .pivot_table(
            index=[
                "clade",
                "R",
                "date",
                "muts_from_DMS_clade",
                "muts_from_DMS_clade_missing_data",
                "number sequences",
            ],
            columns="DMS_phenotype",
            values=["full spike", "RBD", "non RBD"],
        )
    )
    # flatten column names
    assert all(len(c) == 2 for c in ols_vars.columns.values)
    ols_vars.columns = [f"{pheno} ({domain})" for domain, pheno in ols_vars.columns.values]
    return ols_vars.reset_index()
ols_vars = pivot_for_ols_vars(pango_dms_growth_df)

# fit one weighted regression for the full spike and one splitting RBD / non-RBD
# https://www.einblick.ai/python-code-examples/ordinary-least-squares-regression-statsmodels/
for name, exog_vars, regression_chartfile, corr_chartfile in [
    (
        "full spike",
        [f"{c} (full spike)" for c in phenotypes],
        pango_dms_vs_growth_regression_html,
        pango_dms_vs_growth_corr_html
    ),
    (
        "separate RBD and non-RBD",
        [f"{c} ({d})" for d in ["RBD", "non RBD"] for c in phenotypes],
        pango_dms_vs_growth_regression_by_domain_html,
        pango_dms_vs_growth_corr_by_domain_html,
    ),
]:
    print(f"\n\nFitting for {name}:")
    ols_model = statsmodels.api.WLS(
        endog=ols_vars[["R"]],
        exog=statsmodels.api.add_constant(ols_vars[exog_vars]),
        # weight by log n sequences, so pass log**2
        weights=numpy.log(ols_vars["number sequences"])**2,
    )
    res_ols = ols_model.fit()
    display(res_ols.summary())

    fitted_df = ols_vars.assign(DMS_predicted_growth=res_ols.predict())

    plot_size = 180

    # highlight a clade across panels on mouseover
    clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

    # slider to hide clades with few sequences (log10 scale)
    n_sequences_init = int(10 * math.log10(fitted_df["number sequences"].min())) / 10
    n_sequences_slider = alt.param(
        value=n_sequences_init,
        bind=alt.binding_range(
            name="minimum log10 number sequences in clade",
            min=n_sequences_init,
            max=math.log10(fitted_df["number sequences"].max() / 10),
        ),
    )

    # date slider: https://stackoverflow.com/a/67941109
    select_date = alt.selection_interval(encodings=["x"])
    date_slider = (
        alt.Chart(fitted_df[["clade", "date"]].drop_duplicates())
        .mark_bar(color="black")
        .encode(
            x=alt.X(
                "date",
                title="zoom bar to select clades by designation date",
                axis=alt.Axis(format="%b-%Y"),
            ),
            y=alt.Y("count()", title=["number", "clades"]),
        )
        .properties(width=1.5 * plot_size, height=45)
        .add_params(select_date)
    )

    # base chart shared by the regression and simple-correlation panels
    base_growth_chart = (
        alt.Chart(fitted_df)
        .transform_filter(
            alt.expr.log(alt.datum["number sequences"]) / math.log(10) >= n_sequences_slider
        )
        .transform_filter(select_date)
        .encode(
            size=alt.Size(
                "number sequences",
                scale=alt.Scale(
                    type="log",
                    nice=False,
                    range=[15, 250],
                ),
                legend=alt.Legend(symbolStrokeWidth=0, symbolFillColor="gray"),
            ),
            strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0.5)),
            strokeOpacity=alt.condition(clade_selection, alt.value(1), alt.value(0.5)),
            tooltip=[
                "clade",
                alt.Tooltip("R", title="growth rate (R)", format=".1f"),
                alt.Tooltip("DMS_predicted_growth", title="DMS predicted growth", format=".1f"),
                alt.Tooltip("number sequences", format=".2g"),
                alt.Tooltip("date", title="designation date"),
                alt.Tooltip("muts_from_DMS_clade", title=f"muts from {dms_clade}"),
                alt.Tooltip("muts_from_DMS_clade_missing_data", title="muts missing DMS data"),
                *[alt.Tooltip(v, format=".2f") for v in exog_vars],
            ],
        )
        .properties(width=plot_size, height=plot_size)
        .add_params(clade_selection, n_sequences_slider)
    )

    # one panel per exogenous variable; phenotypes cycle per spike domain
    growth_charts = []
    simple_corr_charts = []
    for i, (dms_pheno, pheno) in enumerate(zip(
        exog_vars,
        itertools.cycle(phenotypes)
    )):
        assert dms_pheno.startswith(pheno)
        base_pheno_chart = (
            base_growth_chart
            .encode(
                y=alt.Y(
                    "R",
                    title="actual clade growth rate (R)",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                    # y-axis only on the first panel of each row
                    axis=None if i % len(phenotypes) else alt.Axis(),
                ),
            )
        )
        growth_charts.append(
            base_pheno_chart
            .encode(
                x=alt.X(
                    "DMS_predicted_growth",
                    title="DMS predicted clade growth",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                color=alt.Color(
                    dms_pheno,
                    title=None,
                    legend=alt.Legend(
                        orient="top",
                        titleFontSize=12,
                        gradientLength=plot_size,
                        gradientThickness=10,
                        offset=5,
                        tickCount=3,
                    ),
                    scale=alt.Scale(
                        range=polyclonal.plot.color_gradient_hex("lightgray", phenotype_colors[pheno], 40),
                        nice=False,
                    ),
                ),
            )
            .mark_circle(stroke="black", fillOpacity=0.6)
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=(
                        f"coefficient: {res_ols.params[dms_pheno]:.1f} "
                        # https://stackoverflow.com/a/53966201
                        + f"\u00B1 {res_ols.bse[dms_pheno]:.1f}, "
                        + f"P: {res_ols.pvalues[dms_pheno]:.1g}"
                    ),
                    subtitleFontSize=11,
                ),
            )
        )
        pheno_r, pheno_p = scipy.stats.pearsonr(fitted_df["R"], fitted_df[dms_pheno])
        simple_corr_charts.append(
            base_pheno_chart
            .transform_calculate(color_phenotype=f"'{pheno}'")
            .encode(
                x=alt.X(
                    dms_pheno,
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                color=alt.Color(
                    "color_phenotype:N",
                    scale=alt.Scale(
                        range=list(phenotype_colors.values()),
                        domain=list(phenotype_colors.keys()),
                    ),
                    legend=None,
                ),
            )
            .mark_circle(stroke="black", fillOpacity=0.3, color=phenotype_colors[pheno])
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=f"Pearson r: {pheno_r:.2f}",
                    subtitleFontSize=11,
                ),
            )
        )

    actual_r = math.sqrt(res_ols.rsquared)

    # lay panels out in rows of len(phenotypes), sharing the date slider
    assert len(growth_charts) % len(phenotypes) == 0
    growth_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *growth_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    ).resolve_scale(color="independent")
                    for i in range(len(growth_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                f"Weighted linear regression of DMS phenotypes vs clade growth (Pearson r = {actual_r:.2f})",
                anchor="middle",
                fontSize=14,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )
    simple_corr_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *simple_corr_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    )
                    for i in range(len(simple_corr_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                "Simple correlations of DMS phenotypes vs clade growth",
                anchor="middle",
                fontSize=14,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )
    display(growth_chart)
    print(f"Saving to {regression_chartfile}")
    growth_chart.save(regression_chartfile)
    display(simple_corr_chart)
    print(f"Saving to {corr_chartfile}")
    simple_corr_chart.save(corr_chartfile)

    # fit randomized models and compute P-value based on R values
    print("Computing P-value from randomizations")
    rand_r = []
    for randomseed, rand_df in pango_dms_growth_df_rand.groupby("randomize"):
        rand_ols_vars = pivot_for_ols_vars(rand_df)
        rand_ols_model = statsmodels.api.WLS(
            endog=rand_ols_vars[["R"]],
            exog=statsmodels.api.add_constant(rand_ols_vars[exog_vars]),
            # weight by log n sequences, so pass log**2
            weights=numpy.log(rand_ols_vars["number sequences"])**2,
        )
        rand_res_ols = rand_ols_model.fit()
        rand_r.append(math.sqrt(rand_res_ols.rsquared))
    # P-value: fraction of randomizations with r at least as large as actual
    n_rand_ge = sum(r >= actual_r for r in rand_r)
    pval = f"= {n_rand_ge / len(rand_r)}" if n_rand_ge else f"< {1 / len(rand_r)}"
    rand_r_hist = (
        alt.Chart(pd.DataFrame({"r": rand_r}))
        .encode(
            x=alt.X(
                "r",
                title="Pearson r",
                bin=alt.BinParams(step=0.02, extent=(0, 1)),
                scale=alt.Scale(domain=(0, 1)),
                axis=alt.Axis(values=[0, 0.2, 0.4, 0.6, 0.8, 1]),
            ),
            y=alt.Y("count()", title="number of randomizations"),
        )
        .mark_bar(color="black", opacity=0.65, align="right")
        .properties(width=250, height=130)
    )
    # red dashed rule marking the actual (non-randomized) r
    actual_r_line = (
        alt.Chart(pd.DataFrame({"r": [actual_r]}))
        .encode(x="r")
        .mark_rule(size=2, color="red", strokeDash=[4, 2])
    )
    pval_chart = (
        (rand_r_hist + actual_r_line)
        .configure_axis(grid=False)
        .properties(
            title=alt.TitleParams(
                f"P {pval}",
                subtitle=f"{n_rand_ge} of {len(rand_r)} randomizations \u2265 actual r of {actual_r:.2f}",
            ),
        )
    )
    display(pval_chart)
Fitting for full spike:
| Dep. Variable: | R | R-squared: | 0.883 |
|---|---|---|---|
| Model: | WLS | Adj. R-squared: | 0.883 |
| Method: | Least Squares | F-statistic: | 2309. |
| Date: | Sun, 08 Oct 2023 | Prob (F-statistic): | 0.00 |
| Time: | 13:06:37 | Log-Likelihood: | -3613.1 |
| No. Observations: | 923 | AIC: | 7234. |
| Df Residuals: | 919 | BIC: | 7254. |
| Df Model: | 3 | ||
| Covariance Type: | nonrobust |
| coef | std err | t | P>|t| | [0.025 | 0.975] | |
|---|---|---|---|---|---|---|
| const | 32.6462 | 0.730 | 44.692 | 0.000 | 31.213 | 34.080 |
| sera escape (full spike) | 24.3328 | 0.570 | 42.684 | 0.000 | 23.214 | 25.452 |
| ACE2 affinity (full spike) | 4.1776 | 1.260 | 3.314 | 0.001 | 1.704 | 6.651 |
| cell entry (full spike) | 13.4457 | 2.225 | 6.043 | 0.000 | 9.079 | 17.812 |
| Omnibus: | 27.151 | Durbin-Watson: | 0.816 |
|---|---|---|---|
| Prob(Omnibus): | 0.000 | Jarque-Bera (JB): | 35.388 |
| Skew: | 0.315 | Prob(JB): | 2.07e-08 |
| Kurtosis: | 3.724 | Cond. No. | 12.5 |
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Saving to results/compare_natural/pango_dms_vs_growth_regression.html
Saving to results/compare_natural/pango_dms_vs_growth_corr.html Computing P-value from randomizations
Fitting for separate RBD and non-RBD:
| Dep. Variable: | R | R-squared: | 0.891 |
|---|---|---|---|
| Model: | WLS | Adj. R-squared: | 0.890 |
| Method: | Least Squares | F-statistic: | 1243. |
| Date: | Sun, 08 Oct 2023 | Prob (F-statistic): | 0.00 |
| Time: | 13:07:13 | Log-Likelihood: | -3581.5 |
| No. Observations: | 923 | AIC: | 7177. |
| Df Residuals: | 916 | BIC: | 7211. |
| Df Model: | 6 | ||
| Covariance Type: | nonrobust |
| coef | std err | t | P>|t| | [0.025 | 0.975] | |
|---|---|---|---|---|---|---|
| const | 33.5352 | 0.731 | 45.883 | 0.000 | 32.101 | 34.970 |
| sera escape (RBD) | 29.3391 | 0.848 | 34.608 | 0.000 | 27.675 | 31.003 |
| ACE2 affinity (RBD) | 3.9192 | 1.357 | 2.888 | 0.004 | 1.256 | 6.582 |
| cell entry (RBD) | -18.1338 | 4.472 | -4.055 | 0.000 | -26.909 | -9.358 |
| sera escape (non RBD) | 40.3985 | 4.835 | 8.355 | 0.000 | 30.909 | 49.888 |
| ACE2 affinity (non RBD) | 9.9229 | 1.996 | 4.972 | 0.000 | 6.006 | 13.840 |
| cell entry (non RBD) | 23.1634 | 2.956 | 7.836 | 0.000 | 17.362 | 28.965 |
| Omnibus: | 44.555 | Durbin-Watson: | 0.908 |
|---|---|---|---|
| Prob(Omnibus): | 0.000 | Jarque-Bera (JB): | 92.936 |
| Skew: | 0.297 | Prob(JB): | 6.59e-21 |
| Kurtosis: | 4.436 | Cond. No. | 27.5 |
Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Saving to results/compare_natural/pango_dms_vs_growth_regression_by_domain.html
Saving to results/compare_natural/pango_dms_vs_growth_corr_by_domain.html Computing P-value from randomizations
Distributions of DMS mutation effects in clades with growth estimates versus all mutations¶
In [12]:
# count occurrences of each mutation among clades with growth estimates
muts_in_clades = collections.Counter(
    pango_dms_growth_df
    [f"muts_from_{dms_clade}"]
    .pipe(lambda s: s[s != ""])
    .str.split("; ")
    .explode()
)
print(f"There are {len(muts_in_clades)} mutations found in clades with growth estimates")

# long-form DMS data for all measured (non-wildtype) mutations
all_muts_dms = (
    dms_summary
    .query("wildtype != mutant")
    .assign(mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"])
    .assign(region=lambda x: x["region"].where(x["region"] == "RBD", "non RBD"))
    .melt(
        id_vars=["mutation", "region"],
        value_vars=phenotypes,
        var_name="DMS_phenotype",
        value_name="phenotype",
    )
    .query("phenotype.notnull()")
)

# stack all mutations ("any") with those seen in Pango clades, the latter
# weighted by how many clades carry them
all_muts_dms = pd.concat(
    [
        all_muts_dms.assign(mutation_type="any", count=1),
        all_muts_dms.query("mutation in @muts_in_clades").assign(
            mutation_type="in Pango clade",
            count=lambda x: x["mutation"].map(muts_in_clades),
        ),
    ]
)

# one histogram per phenotype, faceted by mutation type
for pheno in phenotypes:
    base_hist = (
        alt.Chart(
            all_muts_dms
            .query("DMS_phenotype == @pheno")
            .drop(columns=["DMS_phenotype", "mutation"])
        )
        .encode(
            x=alt.X("phenotype", bin=alt.BinParams(maxbins=50)),
            y=alt.Y("sum(count)", title="mutations"),
            color=alt.value(phenotype_colors[pheno]),
            row=alt.Row("mutation_type", title=None, spacing=5),
        )
        .properties(width=200, height=75, title=pheno)
        .mark_bar()
        .resolve_scale(y="independent")
    )
    display(base_hist)
There are 278 mutations found in clades with growth estimates